Paper Name: DiffDis: Empowering Generative Diffusion Model with Cross-Modal Discrimination Capability

Link: https://openaccess.thecvf.com/content/ICCV2023/papers/Huang_DiffDis_Empowering_Generative_Diffusion_Model_with_Cross-Modal_Discrimination_Capability_ICCV_2023_paper.pdf

Project Members: Furkan Genç, Barış Sarper Tezcan

In [ ]:
# Constants shared by the training, inference, and testing cells below.
WIDTH = 512  # pixel width of training / generated images
HEIGHT = 512  # pixel height of training / generated images
LATENTS_WIDTH = WIDTH // 8  # VAE downsamples by 8x -> 64
LATENTS_HEIGHT = HEIGHT // 8  # VAE downsamples by 8x -> 64
BATCH_SIZE = 1
root_dir = "../dataset/cc3m/train"  # CC3M training split location

# training parameters
num_train_epochs = 6
Lambda = 1.0  # weight of the discriminative (text) loss in the joint objective
save_steps = 5000  # checkpoint / sample-generation interval in global steps

# optimizer parameters
learning_rate = 1e-5
discriminative_learning_rate = 1e-4  # New learning rate for discriminative tasks
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_weight_decay = 1e-4
adam_epsilon = 1e-8

# IMAGE TO TEXT
test_dataset = "CIFAR10"  # Set to "CIFAR100" to use CIFAR100 dataset

# output directory
train_output_dir = "../results/output_1"
test_output_dir = "../results/" + test_dataset
inference_output_dir = "../results/text_to_image/output_1/last"

# Load the models
model_file = "data/v1-5-pruned.ckpt"  # base Stable Diffusion v1.5 checkpoint
train_unet_file = None  # Set to None to finetune from scratch, if specified, the diffusion model will be loaded from this file
test_unet_file = "../results/output_1/last.pt" 
inference_unet_file = "../results/output_1/last.pt"

# EMA parameters
use_ema = False  # Set to True to use EMA
ema_decay = 0.9999
warmup_steps = 1000  # linear LR warmup steps for the non-discriminative param group

# TEXT TO IMAGE
prompt1 = "A river with boats docked and houses in the background"
prompt2 = "A piece of chocolate swirled cake on a plate"
prompt3 = "A large bed sitting next to a small Christmas Tree surrounded by pictures"
prompt4 = "A bear searching for food near the river"
prompts = [prompt1, prompt2, prompt3, prompt4]
uncond_prompt = ""  # Also known as negative prompt
do_cfg = True  # enable classifier-free guidance at inference time
cfg_scale = 3  # min: 1, max: 14
num_samples = 1  # images generated per prompt in the inference cell

# SAMPLER
sampler = "ddpm"
num_inference_steps = 50
seed = 42
In [ ]:
import torch
import torch.nn.functional as F
import os
from tqdm import tqdm
from ddpm import DDPMSampler
from pipeline import get_time_embedding
from dataloader import train_dataloader
import model_loader
import time
from diffusion import TransformerBlock, UNet_Transformer  # Ensure these are correctly imported

import pipeline
from PIL import Image
from pathlib import Path
from transformers import CLIPTokenizer

# Set the device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained Stable Diffusion weights and the DDPM noise scheduler.
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
ddpm = DDPMSampler(generator=None)

if train_unet_file is not None:
    # Resume training: deserialize the checkpoint ONCE (the previous version
    # re-read the whole file from disk for every field, which is very slow for
    # multi-GB checkpoints) and restore model weights + loss bookkeeping.
    print(f"Loading UNet model from {train_unet_file}")
    checkpoint = torch.load(train_unet_file)
    models['diffusion'].load_state_dict(checkpoint['model_state_dict'])
    if 'best_loss' in checkpoint:
        best_loss = checkpoint['best_loss']
        best_step = checkpoint['best_step']
        last_loss = checkpoint['last_loss']
        last_step = checkpoint['last_step']
    else:
        best_loss = float('inf')
        best_step = 0
        last_loss = 0.0
        last_step = 0
else:
    checkpoint = None
    best_loss = float('inf')
    best_step = 0
    last_loss = 0.0
    last_step = 0

# TEXT TO IMAGE
tokenizer = CLIPTokenizer("./data/vocab.json", merges_file="./data/merges.txt")

# Freeze the VAE encoder and CLIP text encoder: only the diffusion UNet trains.
for param in models['encoder'].parameters():
    param.requires_grad = False

for param in models['clip'].parameters():
    param.requires_grad = False

# Set the frozen models to eval mode (disables dropout / BN updates).
models['encoder'].eval()
models['clip'].eval()

# Split UNet parameters so the discriminative (transformer) modules can use a
# larger learning rate than the rest of the network.
discriminative_params = []
non_discriminative_params = []

for name, param in models['diffusion'].named_parameters():
    # NOTE(review): this inspects only the TOP-LEVEL attribute of the UNet
    # (name.split('.')[0]); TransformerBlock / UNet_Transformer modules nested
    # deeper in the hierarchy would not be detected — confirm the model layout.
    if isinstance(getattr(models['diffusion'], name.split('.')[0], None), (TransformerBlock, UNet_Transformer)):
        discriminative_params.append(param)
    else:
        non_discriminative_params.append(param)

# AdamW optimizer with a separate learning rate per parameter group.
optimizer = torch.optim.AdamW([
    {'params': non_discriminative_params, 'lr': learning_rate},
    {'params': discriminative_params, 'lr': discriminative_learning_rate}
], betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay, eps=adam_epsilon)

if checkpoint is not None:
    # Restore the optimizer state from the same already-loaded checkpoint.
    print(f"Loading optimizer state from {train_unet_file}")
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Linear warmup scheduler for non-discriminative parameters
def warmup_lr_lambda(current_step: int):
    """Linearly ramp the LR multiplier from 0 to 1 over `warmup_steps`."""
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return 1.0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[
    warmup_lr_lambda,  # Apply warmup for non-discriminative params
    lambda step: 1.0  # Keep constant learning rate for discriminative params
])

# EMA setup: maintain an exponential moving average copy of the UNet weights.
if use_ema:
    ema_unet = torch.optim.swa_utils.AveragedModel(
        models['diffusion'],
        avg_fn=lambda averaged_model_parameter, model_parameter, num_averaged:
            ema_decay * averaged_model_parameter + (1 - ema_decay) * model_parameter,
    )
/home/furkan/miniconda3/envs/DiffDis/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Loading samples: 5046/5046
Loaded 5046 samples.

In [ ]:
def train(num_train_epochs: int, device: str = "cuda", save_steps: int = 1000) -> None:
    """Fine-tune the DiffDis UNet with a joint generative + discriminative loss.

    Each step: encode a batch to VAE latents, noise both the latents and the
    mean-pooled CLIP text embedding ("text query"), then train the UNet to
    predict the image noise and to denoise the text query.

    Args:
        num_train_epochs: total epoch count (including epochs already done
            when resuming from a checkpoint).
        device: torch device string that models and batches are moved to.
        save_steps: checkpoint / sample-generation interval in global steps.

    Side effects: updates the module-level best/last loss bookkeeping, writes
    checkpoints to `train_output_dir`, and displays generated sample images.
    """
    global best_loss, best_step, last_loss, last_step

    # When resuming, recover epoch/step counters from the restored last_step.
    if train_unet_file is not None:
        first_epoch = last_step // len(train_dataloader)
        global_step = last_step + 1
    else:
        first_epoch = 0
        global_step = 0

    # Loss accumulated over the current `save_steps` window (reset on save).
    accumulator = 0

    # Move models to the device
    models['encoder'].to(device)
    models['clip'].to(device)
    models['diffusion'].to(device)
    if use_ema:
        ema_unet.to(device)

    # NOTE: the parameter is rebound to a tqdm iterator here; from this point
    # on `num_train_epochs` is no longer the plain epoch count.
    num_train_epochs = tqdm(range(first_epoch, num_train_epochs), desc="Epoch")
    for epoch in num_train_epochs:
        train_loss = 0.0
        num_train_steps = len(train_dataloader)
        for step, batch in enumerate(train_dataloader):
            start_time = time.time()

            # Extract images and texts from batch
            images = batch["pixel_values"]
            texts = batch["input_ids"]

            # Move batch to the device
            images = images.to(device)
            texts = texts.to(device)

            # Encode images to latent space via the frozen VAE encoder.
            encoder_noise = torch.randn(images.shape[0], 4, LATENTS_HEIGHT, LATENTS_WIDTH).to(device)  # Shape (BATCH_SIZE, 4, 64, 64)
            latents = models['encoder'](images, encoder_noise)

            # Sample independent diffusion timesteps for image and text branches.
            bsz = latents.shape[0]
            timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()
            text_timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()

            # Add noise to latents and texts
            noisy_latents, image_noise = ddpm.add_noise(latents, timesteps)
            encoder_hidden_states = models['clip'](texts)
            noisy_text_query, text_noise = ddpm.add_noise(encoder_hidden_states, text_timesteps)

            # Get time embeddings
            image_time_embeddings = get_time_embedding(timesteps, is_image=True).to(device)
            # NOTE(review): the text time embedding is built from the IMAGE
            # `timesteps`, while the text query above was noised with
            # `text_timesteps` — confirm whether `text_timesteps` was intended.
            text_time_embeddings = get_time_embedding(timesteps, is_image=False).to(device)
            
            # Mean-pool over tokens and L2-normalize the noisy text query.
            average_noisy_text_query = noisy_text_query.mean(dim=1)
            text_query = F.normalize(average_noisy_text_query, p=2, dim=-1)

            # Randomly drop 10% of text and image conditions (classifier-free guidance)
            if torch.rand(1).item() < 0.1:
                text_query = torch.zeros_like(text_query)
            if torch.rand(1).item() < 0.1:
                # NOTE(review): zeroing `noisy_latents` removes the UNet's image
                # denoising INPUT for this step, not just a conditioning signal —
                # confirm this is the intended form of condition dropping.
                noisy_latents = torch.zeros_like(noisy_latents)

            # Predict the noise residual and compute loss
            image_pred, text_pred = models['diffusion'](noisy_latents, encoder_hidden_states, image_time_embeddings, text_time_embeddings, text_query)
            image_loss = F.mse_loss(image_pred.float(), image_noise.float(), reduction="mean")
            # Text branch regresses the (normalized, noisy) query itself rather
            # than the added noise `text_noise`, which is left unused.
            text_loss = F.mse_loss(text_pred.float(), text_query.float(), reduction="mean")

            # Weighted sum of the generative and discriminative objectives.
            loss = image_loss + Lambda * text_loss
            train_loss += loss.item()
            accumulator += loss.item()

            # Backpropagate
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
            if use_ema:
                ema_unet.update_parameters(models['diffusion'])

            end_time = time.time()

            # Progress line; on the resumed first epoch, offset the displayed step.
            if train_unet_file is not None and epoch == first_epoch:
                print(f"Step: {step+1+last_step}/{num_train_steps+last_step}   Loss: {loss.item()}   Time: {end_time - start_time}", end="\r")
            else:
                print(f"Step: {step}/{num_train_steps}   Loss: {loss.item()}   Time: {end_time - start_time}", end="\r")

            if global_step % save_steps == 0 and global_step > 0:
                # Check if the window-averaged loss is the best seen so far.
                if accumulator / save_steps < best_loss:
                    best_loss = accumulator / save_steps
                    best_step = global_step
                    best_save_path = os.path.join(train_output_dir, "best.pt")
                    if use_ema:
                        torch.save({
                            'model_state_dict': models['diffusion'].state_dict(),
                            'ema_state_dict': ema_unet.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'best_loss': best_loss,
                            'best_step': best_step,
                            'last_loss': accumulator / save_steps,
                            'last_step': global_step
                        }, best_save_path) 
                    else:
                        torch.save({
                            'model_state_dict': models['diffusion'].state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'best_loss': best_loss,
                            'best_step': best_step,
                            'last_loss': accumulator / save_steps,
                            'last_step': global_step
                        }, best_save_path)              

                    print(f"\nNew best model saved to {best_save_path} with loss {best_loss}")

                # Always save a rolling "last" checkpoint as well.
                last_save_path = os.path.join(train_output_dir, f"last.pt")
                if use_ema:
                    torch.save({
                        'model_state_dict': models['diffusion'].state_dict(),
                        'ema_state_dict': ema_unet.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'best_loss': best_loss,
                        'best_step': best_step,
                        'last_loss': accumulator / save_steps,
                        'last_step': global_step
                    }, last_save_path)
                else:
                    torch.save({
                        'model_state_dict': models['diffusion'].state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'best_loss': best_loss,
                        'best_step': best_step,
                        'last_loss': accumulator / save_steps,
                        'last_step': global_step
                    }, last_save_path)
                    
                print(f"Saved state to {last_save_path}")

                # Generate sample images from the fixed prompts as a visual check.
                for i, prompt in enumerate(prompts):
                    # Sample images from the model
                    output_image = pipeline.generate(
                        prompt=prompt,
                        uncond_prompt=uncond_prompt,
                        input_image=None,
                        strength=0.9,
                        do_cfg=do_cfg,
                        cfg_scale=cfg_scale,
                        sampler_name=sampler,
                        n_inference_steps=num_inference_steps,
                        seed=seed,
                        models=models,
                        device=DEVICE,
                        idle_device=DEVICE,
                        tokenizer=tokenizer,
                    )

                    # Convert the raw array to a PIL image.
                    output_image = Image.fromarray(output_image)

                    # Display the generated image
                    display(output_image)

                print(f"\nSaved images for step {global_step}")
                print('Epoch: %d   Step: %d   Loss: %.5f   Best Loss: %.5f   Best Step: %d\n' % (epoch+1, global_step, accumulator / save_steps, best_loss, best_step))

                # Reset the windowed loss accumulator for the next interval.
                accumulator = 0.0

            global_step += 1

        print(f"Average loss over epoch: {train_loss / (step + 1)}")
In [ ]:
# Summarize the full training configuration before launching the run.
config_lines = [
    '==> Training starts..',
    '',
    f'Model file: {model_file}',
    f'UNet file: {train_unet_file}',
    f'Batch size: {BATCH_SIZE}',
    f'Width: {WIDTH}',
    f'Height: {HEIGHT}',
    f'Latents width: {LATENTS_WIDTH}',
    f'Latents height: {LATENTS_HEIGHT}',
    f'First epoch: {last_step // len(train_dataloader)}',
    f'Number of training epochs: {num_train_epochs}',
    f'Lambda: {Lambda}',
    f'Learning rate: {learning_rate}',
    f'Discriminative learning rate: {discriminative_learning_rate}',
    f'Adam beta1: {adam_beta1}',
    f'Adam beta2: {adam_beta2}',
    f'Adam weight decay: {adam_weight_decay}',
    f'Adam epsilon: {adam_epsilon}',
    f'Use EMA: {use_ema}',
    f'EMA decay: {ema_decay}',
    f'Warmup steps: {warmup_steps}',
    f'Output directory: {train_output_dir}',
    f'Save steps: {save_steps}',
    f'Device: {DEVICE}',
    f'Sampler: {sampler}',
    f'Number of inference steps: {num_inference_steps}',
    f'Seed: {seed}',
]
config_lines.extend(f'Prompt {idx + 1}: {text}' for idx, text in enumerate(prompts))
config_lines.extend([
    f'Unconditional prompt: {uncond_prompt}',
    f'Do CFG: {do_cfg}',
    f'CFG scale: {cfg_scale}',
    '',
    '',
])
s = '\n'.join(config_lines)
print(s)

# Make sure the checkpoint directory exists before training writes into it.
os.makedirs(train_output_dir, exist_ok=True)

train(num_train_epochs=num_train_epochs, device=DEVICE, save_steps=save_steps)
==> Training starts..

Model file: data/v1-5-pruned.ckpt
UNet file: None
Batch size: 1
Width: 512
Height: 512
Latents width: 64
Latents height: 64
First epoch: 0
Number of training epochs: 6
Lambda: 1.0
Learning rate: 1e-05
Discriminative learning rate: 0.0001
Adam beta1: 0.9
Adam beta2: 0.999
Adam weight decay: 0.0001
Adam epsilon: 1e-08
Use EMA: False
EMA decay: 0.9999
Warmup steps: 1000
Output directory: ../results/output_1
Save steps: 5000
Device: cuda
Sampler: ddpm
Number of inference steps: 50
Seed: 42
Prompt 1: A river with boats docked and houses in the background
Prompt 2: A piece of chocolate swirled cake on a plate
Prompt 3: A large bed sitting next to a small Christmas Tree surrounded by pictures
Prompt 4: A bear searching for food near the river
Unconditional prompt: 
Do CFG: True
CFG scale: 3


Epoch:   0%|          | 0/6 [00:00<?, ?it/s]/home/furkan/miniconda3/envs/DiffDis/lib/python3.12/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1712609048481/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:84.)
  return F.conv2d(input, weight, bias, self.stride,
/home/furkan/CENG796/code/pipeline.py:188: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  timesteps = torch.tensor(timesteps, dtype=torch.float32)[:, None] # convert the batch of timesteps to a 2-D tensor
Step: 5000/5046   Loss: 0.11019457876682281   Time: 0.1516003608703613388
New best model saved to ../results/output_1/best.pt with loss 0.2574170602272265
Saved state to ../results/output_1/last.pt
100%|██████████| 50/50 [00:04<00:00, 11.53it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
No description has been provided for this image
Saved images for step 5000
Epoch: 1   Step: 5000   Loss: 0.25742   Best Loss: 0.25742   Best Step: 5000

Step: 5044/5046   Loss: 0.9965981841087341   Time: 0.19762754440307617535
Epoch:  17%|█▋        | 1/6 [25:27<2:07:15, 1527.18s/it]
Average loss over epoch: 0.25731121857589336   Time: 0.1969141960144043
Step: 4954/5046   Loss: 0.17604000866413116   Time: 0.1969144344329834474
New best model saved to ../results/output_1/best.pt with loss 0.2406541142154485
Saved state to ../results/output_1/last.pt
100%|██████████| 50/50 [00:04<00:00, 11.83it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.76it/s]
No description has been provided for this image
Saved images for step 10000
Epoch: 2   Step: 10000   Loss: 0.24065   Best Loss: 0.24065   Best Step: 10000

Step: 5044/5046   Loss: 0.08904671669006348   Time: 0.1974637508392334577
Epoch:  33%|███▎      | 2/6 [49:31<1:38:33, 1478.35s/it]
Average loss over epoch: 0.24081591752830642 Time: 0.19686222076416016
Step: 4908/5046   Loss: 0.0033800562378019094   Time: 0.19713997840881348
New best model saved to ../results/output_1/best.pt with loss 0.23109131355108692
Saved state to ../results/output_1/last.pt
100%|██████████| 50/50 [00:04<00:00, 11.83it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
No description has been provided for this image
Saved images for step 15000
Epoch: 3   Step: 15000   Loss: 0.23109   Best Loss: 0.23109   Best Step: 15000

Step: 5044/5046   Loss: 0.13101424276828766   Time: 0.196987390518188483
Epoch:  50%|█████     | 3/6 [1:13:07<1:12:29, 1449.88s/it]
Average loss over epoch: 0.2313138278532704  Time: 0.19705843925476074
Saved state to ../results/output_1/last.pt  Time: 0.197257280349731459523
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.82it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
No description has been provided for this image
Saved images for step 20000
Epoch: 4   Step: 20000   Loss: 0.23762   Best Loss: 0.23109   Best Step: 15000

Step: 5044/5046   Loss: 0.03353509679436684   Time: 0.2095785140991211436
Epoch:  67%|██████▋   | 4/6 [1:36:32<47:44, 1432.32s/it]  
Average loss over epoch: 0.23832769447767252 Time: 0.20766687393188477
Saved state to ../results/output_1/last.pt9   Time: 0.1973145008087158282
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
No description has been provided for this image
Saved images for step 25000
Epoch: 5   Step: 25000   Loss: 0.23724   Best Loss: 0.23109   Best Step: 15000

Step: 5044/5046   Loss: 0.17421704530715942   Time: 0.1988165378570556664
Epoch:  83%|████████▎ | 5/6 [1:59:59<23:43, 1423.16s/it]
Average loss over epoch: 0.23611877438210113  Time: 0.19716739654541016
Saved state to ../results/output_1/last.pt8   Time: 0.1971397399902343833
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.81it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.80it/s]
No description has been provided for this image
100%|██████████| 50/50 [00:04<00:00, 11.80it/s]
No description has been provided for this image
Saved images for step 30000
Epoch: 6   Step: 30000   Loss: 0.23115   Best Loss: 0.23109   Best Step: 15000

Step: 5044/5046   Loss: 1.00068199634552   Time: 0.1973047256469726630573
Epoch: 100%|██████████| 6/6 [2:23:25<00:00, 1434.26s/it]
Average loss over epoch: 0.23221419323251438  Time: 0.19747495651245117

In [ ]:
import torch
import torch.nn.functional as F
import os
import model_loader
import time
import pipeline
from PIL import Image
from transformers import CLIPTokenizer
from IPython.display import display

# Prefer the GPU when one is available, otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained Stable Diffusion weights.
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)

# Optionally overwrite the diffusion UNet with fine-tuned weights.
if inference_unet_file is not None:
    print(f"Loading UNet model from {inference_unet_file}")
    checkpoint = torch.load(inference_unet_file)
    models['diffusion'].load_state_dict(checkpoint['model_state_dict'])

# Tokenizer used to encode the text prompts (TEXT TO IMAGE).
tokenizer = CLIPTokenizer("./data/vocab.json", merges_file="./data/merges.txt")

# Generate and display `num_samples` images for every prompt.
for prompt_idx, prompt in enumerate(prompts):
    for sample_idx in range(num_samples):
        start = time.time()

        generated = pipeline.generate(
            prompt=prompt,
            uncond_prompt=uncond_prompt,
            input_image=None,
            strength=0.9,
            do_cfg=do_cfg,
            cfg_scale=cfg_scale,
            sampler_name=sampler,
            n_inference_steps=num_inference_steps,
            seed=seed,
            models=models,
            device=DEVICE,
            idle_device=DEVICE,
            tokenizer=tokenizer,
        )

        elapsed = time.time() - start

        print(f"PROMPT {prompt_idx + 1} - SAMPLE {sample_idx + 1} - TIME: {elapsed:.2f}s\n")

        # Convert the raw array to a PIL image and show it inline.
        display(Image.fromarray(generated))
/home/furkan/miniconda3/envs/DiffDis/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Loading UNet model from ../results/output_1/last.pt
  0%|          | 0/50 [00:00<?, ?it/s]/home/furkan/CENG796/code/pipeline.py:188: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  timesteps = torch.tensor(timesteps, dtype=torch.float32)[:, None] # convert the batch of timesteps to a 2-D tensor
/home/furkan/miniconda3/envs/DiffDis/lib/python3.12/site-packages/torch/nn/modules/conv.py:456: UserWarning: Applied workaround for CuDNN issue, install nvrtc.so (Triggered internally at /opt/conda/conda-bld/pytorch_1712609048481/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:84.)
  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 50/50 [00:07<00:00,  6.58it/s]
PROMPT 1 - SAMPLE 1 - TIME: 7.85s

No description has been provided for this image
100%|██████████| 50/50 [00:07<00:00,  6.61it/s]
PROMPT 2 - SAMPLE 1 - TIME: 7.77s

No description has been provided for this image
100%|██████████| 50/50 [00:07<00:00,  6.66it/s]
PROMPT 3 - SAMPLE 1 - TIME: 7.71s

No description has been provided for this image
100%|██████████| 50/50 [00:07<00:00,  6.60it/s]
PROMPT 4 - SAMPLE 1 - TIME: 7.78s

No description has been provided for this image
In [ ]:
import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import os
from ddpm import DDPMSampler
from pipeline import get_time_embedding
import model_loader
import time
from transformers import CLIPTokenizer

# Prefer the GPU when one is available, otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained weights and the DDPM noise scheduler.
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)
ddpm = DDPMSampler(generator=None)

# Optionally replace the UNet weights with a fine-tuned checkpoint, picking
# the EMA copy when EMA was enabled during training.
if test_unet_file is not None:
    print(f"Loading UNet model from {test_unet_file}")
    checkpoint = torch.load(test_unet_file)
    weights_key = 'ema_state_dict' if use_ema else 'model_state_dict'
    models['diffusion'].load_state_dict(checkpoint[weights_key])

# Tokenizer used to encode class names (TEXT TO IMAGE).
tokenizer = CLIPTokenizer("./data/vocab.json", merges_file="./data/merges.txt")

# Put every sub-model in evaluation mode (no dropout / BN updates).
for module_key in ('encoder', 'clip', 'diffusion'):
    models[module_key].eval()

print("==> Testing starts..")
Loading UNet model from ../results/output_1/last.pt
==> Testing starts..
In [ ]:
def test(device: str = "cuda") -> None:
    """Evaluate the fine-tuned UNet as a zero-shot classifier on CIFAR-10/100.

    Each test image is VAE-encoded and noised; the CLIP embedding of its
    ground-truth class name is noised into a text query; the UNet's denoised
    text prediction is matched against all class embeddings by cosine
    similarity, and the argmax gives the predicted class.

    Args:
        device: torch device string that models and batches are moved to.

    Reads the module-level `test_dataset` to choose CIFAR10 vs CIFAR100.
    """
    # Resize to the training resolution and map pixels to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((WIDTH, HEIGHT), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    # Load the requested test set (CIFAR-10 or CIFAR-100).
    if test_dataset == "CIFAR10":
        testset = torchvision.datasets.CIFAR10(
            root='../dataset', train=False, download=True, transform=transform)

    elif test_dataset == "CIFAR100":
        testset = torchvision.datasets.CIFAR100(
            root='../dataset', train=False, download=True, transform=transform)

    print(f"Test dataset: {test_dataset} | Number of test samples: {len(testset)}")

    # Load the test data
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # Move models to the device
    models['encoder'].to(device)
    models['clip'].to(device)
    models['diffusion'].to(device)

    # Define the class names and tokens
    class_names = testset.classes
    class_tokens = []

    # Tokenize every class name to a fixed-length (77) CLIP token sequence.
    for class_name in class_names:
        # Tokenize text
        tokens = tokenizer.batch_encode_plus(
            [class_name], padding="max_length", max_length=77
        ).input_ids
        tokens = torch.tensor(tokens, dtype=torch.long).squeeze()
        class_tokens.append(tokens)

    # Convert list of class tokens to a tensor
    class_tokens = torch.stack(class_tokens).to(device)
    print(f"Class tokens shape: {class_tokens.shape}")

    # Encode class tokens with the CLIP model
    with torch.no_grad():
        # Encode class tokens
        encoder_hidden_states = models['clip'](class_tokens)

        # Mean-pool over tokens and L2-normalize -> one vector per class.
        class_embeddings = encoder_hidden_states.mean(dim=1)
        class_embeddings = F.normalize(class_embeddings, p=2, dim=-1)
        print(f"Class embeddings shape: {class_embeddings.shape}\n")
    
    # Start testing
    test_loss = 0.0
    num_test_steps = len(testloader)
    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(testloader):
            start_time = time.time()

            # Move batch to the device
            images = images.to(device)
            targets = targets.to(device)
            # Token sequences of each sample's ground-truth class name.
            texts = [class_tokens[target] for target in targets]
            
            # Convert list of class tokens to a tensor
            texts = torch.stack(texts).to(device)

            # Encode images to latent space
            encoder_noise = torch.randn(images.shape[0], 4, LATENTS_HEIGHT, LATENTS_WIDTH).to(device)  # Shape (BATCH_SIZE, 4, 64, 64)
            latents = models['encoder'](images, encoder_noise)

            # Sample noise and timesteps for diffusion process
            bsz = latents.shape[0]
            timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()
            text_timesteps = torch.randint(0, ddpm.num_train_timesteps, (bsz,), device=latents.device).long()

            # Add noise to latents and texts
            noisy_latents, image_noise = ddpm.add_noise(latents, timesteps)
            encoder_hidden_states = models['clip'](texts)
            noisy_text_query, text_noise = ddpm.add_noise(encoder_hidden_states, text_timesteps)

            # Get time embeddings
            image_time_embeddings = get_time_embedding(timesteps, is_image=True).to(device)
            # NOTE(review): built from the image `timesteps`, not `text_timesteps`
            # (same pattern as in train()) — confirm this mirrors the intended design.
            text_time_embeddings = get_time_embedding(timesteps, is_image=False).to(device)
            
            # Average and normalize text time embeddings
            average_noisy_text_query = noisy_text_query.mean(dim=1)
            text_query = F.normalize(average_noisy_text_query, p=2, dim=-1)

            # Randomly drop 10% of text and image conditions (classifier-free guidance)
            # NOTE(review): random condition-dropping is kept at evaluation time,
            # and `text_query` is derived from the ground-truth class embedding —
            # the model receives label information as input, which may inflate
            # the reported accuracy; verify the evaluation protocol.
            if torch.rand(1).item() < 0.1:
                text_query = torch.zeros_like(text_query)
            if torch.rand(1).item() < 0.1:
                noisy_latents = torch.zeros_like(noisy_latents)

            # Predict the noise residual and compute loss
            _, text_pred = models['diffusion'](noisy_latents, encoder_hidden_states, image_time_embeddings, text_time_embeddings, text_query)
                
            # Calculate loss
            loss = F.mse_loss(text_pred.float(), text_query.float(), reduction="mean")
            test_loss += loss.item()
            
            # Classify: nearest class embedding by cosine similarity.
            similarities = F.cosine_similarity(text_pred.unsqueeze(1), class_embeddings.unsqueeze(0), dim=-1)
            predicted_classes = similarities.argmax(dim=-1)

            # Compare predictions with actual targets
            correct_predictions += (predicted_classes == targets).sum().item()
            total_predictions += targets.size(0)

            end_time = time.time()

            print(f"Batch {batch_idx + 1}/{num_test_steps} | Loss: {loss:.4f} | Time: {end_time - start_time:.2f}s", end="\r")

    # Calculate total accuracy
    accuracy = correct_predictions / total_predictions
    # The f-string fills the counts; the remaining %.2f%% is filled by %-formatting.
    s = f"Accuracy: %.2f%% ({correct_predictions}/{total_predictions})" % (accuracy * 100)
    s += f"\nTest Loss: {test_loss / num_test_steps:.4f}"
    print("\n" + s)
In [ ]:
# Summarize the evaluation configuration before running the test pass.
config_lines = [
    '==> Testing starts..',
    '',
    f'Test dataset: {test_dataset}',
    f'Model file: {model_file}',
    f'UNet file: {test_unet_file}',
    f'Batch size: {BATCH_SIZE}',
    f'Width: {WIDTH}',
    f'Height: {HEIGHT}',
    f'Latents width: {LATENTS_WIDTH}',
    f'Latents height: {LATENTS_HEIGHT}',
    f'Number of training epochs: {num_train_epochs}',
    f'Lambda: {Lambda}',
    f'Learning rate: {learning_rate}',
    f'Discriminative learning rate: {discriminative_learning_rate}',
    f'Adam beta1: {adam_beta1}',
    f'Adam beta2: {adam_beta2}',
    f'Adam weight decay: {adam_weight_decay}',
    f'Adam epsilon: {adam_epsilon}',
    f'Use EMA: {use_ema}',
    f'EMA decay: {ema_decay}',
    f'Warmup steps: {warmup_steps}',
    f'Output directory: {test_output_dir}',
    f'Save steps: {save_steps}',
    f'Device: {DEVICE}',
    f'Sampler: {sampler}',
    f'Number of inference steps: {num_inference_steps}',
    f'Seed: {seed}',
]
config_lines.extend(f'Prompt {idx + 1}: {text}' for idx, text in enumerate(prompts))
config_lines.extend([
    f'Unconditional prompt: {uncond_prompt}',
    f'Do CFG: {do_cfg}',
    f'CFG scale: {cfg_scale}',
    '',
    '',
])
s = '\n'.join(config_lines)
print(s)

# Run zero-shot classification on the configured test dataset.
test(device=DEVICE)
==> Testing starts..

Test dataset: CIFAR10
Model file: data/v1-5-pruned.ckpt
UNet file: ../results/output_1/last.pt
Batch size: 1
Width: 512
Height: 512
Latents width: 64
Latents height: 64
Number of training epochs: 6
Lambda: 1.0
Learning rate: 1e-05
Discriminative learning rate: 0.0001
Adam beta1: 0.9
Adam beta2: 0.999
Adam weight decay: 0.0001
Adam epsilon: 1e-08
Use EMA: False
EMA decay: 0.9999
Warmup steps: 1000
Output directory: ../results/CIFAR10
Save steps: 5000
Device: cuda
Sampler: ddpm
Number of inference steps: 50
Seed: 42
Prompt 1: A river with boats docked and houses in the background
Prompt 2: A piece of chocolate swirled cake on a plate
Prompt 3: A large bed sitting next to a small Christmas Tree surrounded by pictures
Prompt 4: A bear searching for food near the river
Unconditional prompt: 
Do CFG: True
CFG scale: 3


Files already downloaded and verified
Test dataset: CIFAR10 | Number of test samples: 10000
Class tokens shape: torch.Size([10, 77])
Class embeddings shape: torch.Size([10, 768])

Batch 10000/10000 | Loss: 0.0011 | Time: 0.08s
Accuracy: 99.93% (9993/10000)
Test Loss: 0.0007
In [ ]:
# Test the model on the CIFAR-100 dataset
test_dataset = "CIFAR100"
test_output_dir = "../results/" + test_dataset

test(device=DEVICE)
Files already downloaded and verified
Test dataset: CIFAR100 | Number of test samples: 10000
Class tokens shape: torch.Size([100, 77])
Class embeddings shape: torch.Size([100, 768])

Batch 10000/10000 | Loss: 0.0008 | Time: 0.08s
Accuracy: 92.74% (9274/10000)
Test Loss: 0.0008